{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Neural Networks in Practice"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
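    "## TensorDataset\n",
    "\n",
    "`TensorDataset` is a dataset class in `torch.utils.data` that wraps one or more tensors into a single dataset for model training. All tensors must have the same size in their first dimension, and sample `i` of the dataset is the tuple of the `i`-th slice of each tensor.\n",
    "\n",
    "Its main uses are:\n",
    "\n",
    "- Combining features and labels into one dataset\n",
    "\n",
    "- Wrapping any number of tensors, such as images and labels\n",
    "\n",
    "- Indexing individual samples, e.g. `dataset[0]`\n",
    "\n",
    "- Working with `DataLoader` for batched reading\n",
    "\n",
    "This keeps the organization and reading of tensor data concise."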
   ]
  },
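  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The cell below is a minimal sketch of wrapping more than two tensors. The toy shapes and the names `images`, `targets`, `weights` and `toy` are assumptions made for illustration; they are not part of the original notebook."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch.utils.data import TensorDataset\n",
    "\n",
    "# Hypothetical toy data: 6 single-channel 4x4 images, class labels, per-sample weights\n",
    "images = torch.rand(6, 1, 4, 4)\n",
    "targets = torch.randint(0, 3, (6,))\n",
    "weights = torch.ones(6)\n",
    "\n",
    "# All three tensors share the first dimension (6), as TensorDataset requires\n",
    "toy = TensorDataset(images, targets, weights)\n",
    "\n",
    "print(len(toy))               # number of samples: 6\n",
    "img, target, weight = toy[0]  # one sample is a tuple with one entry per tensor\n",
    "print(img.shape)              # torch.Size([1, 4, 4])"
   ]
  },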
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
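    "## DataLoader\n",
    "\n",
    "`DataLoader` is PyTorch's standard class for iterating over a dataset. Its main features are:\n",
    "\n",
    "1. Data reading\n",
    "\n",
    "`DataLoader` reads a `Dataset` batch by batch and supports multi-process loading. It does not move data to the GPU by itself; with `pin_memory=True` it places batches in pinned host memory so that a later `.to(device)` copy is faster.\n",
    "\n",
    "2. Batching\n",
    "\n",
    "`batch_size` sets how many samples are collated into each batch, e.g. `DataLoader(dataset, batch_size=8)`.\n",
    "\n",
    "3. Shuffling\n",
    "\n",
    "`shuffle=True` reshuffles the data at the start of every epoch.\n",
    "\n",
    "4. Multi-process loading\n",
    "\n",
    "`num_workers` sets how many worker subprocesses load data; the default `0` loads in the main process.\n",
    "\n",
    "5. Sampling\n",
    "\n",
    "A `Sampler` passed via `sampler=` customizes the order in which samples are drawn; it cannot be combined with `shuffle=True`.\n",
    "\n",
    "In short, `DataLoader` is the standard way to feed a dataset to a model in PyTorch."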
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "from torch import optim\n",
    "from torch.nn import functional as F\n",
    "from torch.utils.data import TensorDataset\n",
    "from torch.utils.data import DataLoader\n",
    "\n",
    "import torchvision\n",
    "import torchvision.transforms as transforms\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create feature and label tensors: 100 samples, 5 features, labels in [0, 10)\n",
    "features = torch.randn(100, 5)\n",
    "labels = torch.randint(0, 10, (100,))\n",
    "\n",
    "# Wrap features and labels in a TensorDataset\n",
    "dataset = TensorDataset(features, labels)\n",
    "\n",
    "# Create the data loader\n",
    "dataloader = DataLoader(dataset, batch_size=8)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([ 0.5523,  0.2738,  0.1225, -0.1789, -0.8072]), tensor(2))"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dataset[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[tensor([[-0.0114, -0.0916,  1.3767, -1.8570,  0.6364],\n",
      "        [ 0.1598, -0.3302,  0.4382,  1.2717,  1.1800],\n",
      "        [-0.5495,  0.5485,  0.0685, -1.4578,  1.5613],\n",
      "        [ 0.6499,  0.0888,  0.5256, -1.2434, -0.4251],\n",
      "        [-1.4292, -0.5635,  0.4385,  0.4239,  0.0719],\n",
      "        [ 0.3692,  0.3452,  0.5108,  0.5149, -0.3104],\n",
      "        [ 0.2049, -1.6116,  0.0566,  0.4828,  0.2168],\n",
      "        [ 0.1305, -1.4151, -0.8409, -0.5174,  1.1110]]), tensor([0, 8, 3, 3, 6, 1, 8, 8])]\n"
     ]
    }
   ],
   "source": [
    "for i in dataloader:\n",
    "    print(i)\n",
    "    break"
   ]
  },
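  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The batching and shuffling options described for `DataLoader` can be sketched as follows. The 10-sample dataset `toy` and the choice of `batch_size=4` with `drop_last=True` are assumptions made for illustration. The printed order differs from run to run because no seed is set."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch.utils.data import TensorDataset, DataLoader\n",
    "\n",
    "# Hypothetical dataset: feature i is the number i, and so is its label\n",
    "toy = TensorDataset(torch.arange(10).float().view(10, 1), torch.arange(10))\n",
    "\n",
    "# shuffle=True draws a new random order each epoch;\n",
    "# drop_last=True discards the final incomplete batch (10 = 4 + 4 + 2)\n",
    "loader = DataLoader(toy, batch_size=4, shuffle=True, drop_last=True)\n",
    "\n",
    "for epoch in range(2):\n",
    "    for xb, yb in loader:\n",
    "        print(epoch, yb.tolist())\n",
    "\n",
    "print(len(loader))  # number of batches per epoch: 2"
   ]
  },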
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([9, 2, 7, 0, 8, 6, 8, 5, 9, 5, 9, 1, 5, 1, 8, 5, 0, 0, 2, 6, 8, 9, 2, 9,\n",
       "        6, 9, 6, 9, 9, 5, 5, 6, 3, 6, 8, 8, 3, 7, 2, 4, 0, 6, 8, 4, 7, 7, 4, 6,\n",
       "        1, 3, 4, 5, 4, 6, 1, 9, 2, 2, 5, 3, 3, 1, 4, 7, 7, 2, 2, 0, 4, 3, 2, 6,\n",
       "        5, 7, 6, 5, 8, 0, 6, 6, 0, 6, 4, 0, 1, 0, 7, 6, 5, 9, 2, 9, 5, 3, 2, 2,\n",
       "        5, 6, 2, 0])"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.randint(0, 10, (100,))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "y = torch.randint(low=0,high=3,size=(500,1),dtype=torch.float32)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
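    "## Switching between multi-class classification, binary classification, and regression\n",
    "\n",
    "The hidden layers can stay the same across the three tasks. What changes is the size of the output layer, the loss function, and the type and shape of the labels:\n",
    "\n",
    "- Regression: `nn.MSELoss()`, float targets with the same shape as the output\n",
    "\n",
    "- Binary classification: `nn.BCEWithLogitsLoss()`, a single output logit, float targets in [0, 1] with the same shape as the output\n",
    "\n",
    "- Multi-class classification: `nn.CrossEntropyLoss()`, one logit per class, integer (`long`) class indices of shape `(N,)`"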
   ]
  },
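  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A minimal sketch of what each loss expects. The random tensors stand in for model outputs, and all variable names here are hypothetical; only the three loss classes come from PyTorch."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "\n",
    "torch.manual_seed(0)\n",
    "n = 4\n",
    "logits_multi = torch.randn(n, 3)   # 3-class output: one logit per class\n",
    "logits_binary = torch.randn(n, 1)  # binary output: a single logit\n",
    "pred_reg = torch.randn(n, 1)       # regression output: a single value\n",
    "\n",
    "# Multi-class: integer class indices, shape (n,), dtype long\n",
    "y_multi = torch.randint(0, 3, (n,))\n",
    "loss_multi = nn.CrossEntropyLoss()(logits_multi, y_multi)\n",
    "\n",
    "# Binary: float targets in {0., 1.}, same shape as the logits\n",
    "y_binary = torch.randint(0, 2, (n, 1)).float()\n",
    "loss_binary = nn.BCEWithLogitsLoss()(logits_binary, y_binary)\n",
    "\n",
    "# Regression: float targets, same shape as the predictions\n",
    "y_reg = torch.randn(n, 1)\n",
    "loss_reg = nn.MSELoss()(pred_reg, y_reg)\n",
    "\n",
    "print(loss_multi.item(), loss_binary.item(), loss_reg.item())"
   ]
  },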
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define the network architecture\n",
    "class Model(nn.Module):\n",
    "    def __init__(self, in_features=10, out_features=2):\n",
    "        super().__init__()  # initialize the parent class nn.Module\n",
    "        self.linear1 = nn.Linear(in_features, 50, bias=True)  # first hidden layer; the input layer is implicit\n",
    "        self.linear2 = nn.Linear(50, 20, bias=True)\n",
    "        self.output = nn.Linear(20, out_features, bias=True)\n",
    "\n",
    "    def forward(self, x):\n",
    "        z1 = self.linear1(x)\n",
    "        z1_relu = torch.relu(z1)\n",
    "        z2 = self.linear2(z1_relu)\n",
    "        z2_relu = torch.relu(z2)\n",
    "        z3 = self.output(z2_relu)\n",
    "        # Return raw logits: nn.CrossEntropyLoss applies log-softmax internally,\n",
    "        # so F.softmax(z3, dim=1) is not applied here\n",
    "        return z3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
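    "# Hyperparameters\n",
    "lr = 0.1      # learning rate\n",
    "gamma = 0.9   # momentum coefficient\n",
    "epochs = 5\n",
    "bs = 50       # batch size\n",
    "\n",
    "# Random data: 500 samples with 20 features and labels in {0, 1, 2}.\n",
    "# Y is float with shape (500, 1); the training loop reshapes it to (batch,)\n",
    "# and casts it to long, as nn.CrossEntropyLoss requires.\n",
    "torch.manual_seed(420)\n",
    "X = torch.rand((500, 20), dtype=torch.float32)\n",
    "Y = torch.randint(low=0, high=3, size=(500, 1), dtype=torch.float32)\n",
    "\n",
    "data = TensorDataset(X, Y)\n",
    "batchdata = DataLoader(data, batch_size=bs, shuffle=True)\n",
    "\n",
    "input_ = X.shape[1]        # number of features\n",
    "output_ = len(Y.unique())  # number of classes\n",
    "\n",
    "torch.manual_seed(420)\n",
    "net = Model(in_features=input_, out_features=output_)\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "opt = optim.SGD(net.parameters(), lr=lr, momentum=gamma)  # SGD with momentum"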
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "#Y.view(X.shape[0]).long()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch:0 , loss:1.0957516431808472\n",
      "epoch:0 , loss:1.1143020391464233\n",
      "epoch:0 , loss:1.1196856498718262\n",
      "epoch:0 , loss:1.094942569732666\n",
      "epoch:0 , loss:1.0956664085388184\n",
      "epoch:0 , loss:1.0987811088562012\n",
      "epoch:0 , loss:1.0922518968582153\n",
      "epoch:0 , loss:1.101608395576477\n",
      "epoch:0 , loss:1.0952848196029663\n",
      "epoch:0 , loss:1.1202633380889893\n",
      "epoch:1 , loss:1.0993648767471313\n",
      "epoch:1 , loss:1.0835442543029785\n",
      "epoch:1 , loss:1.1072036027908325\n",
      "epoch:1 , loss:1.0957679748535156\n",
      "epoch:1 , loss:1.0953600406646729\n",
      "epoch:1 , loss:1.0984840393066406\n",
      "epoch:1 , loss:1.1009411811828613\n",
      "epoch:1 , loss:1.0960891246795654\n",
      "epoch:1 , loss:1.1040639877319336\n",
      "epoch:1 , loss:1.0937711000442505\n",
      "epoch:2 , loss:1.1071516275405884\n",
      "epoch:2 , loss:1.0898280143737793\n",
      "epoch:2 , loss:1.0876266956329346\n",
      "epoch:2 , loss:1.1013351678848267\n",
      "epoch:2 , loss:1.087397813796997\n",
      "epoch:2 , loss:1.0829508304595947\n",
      "epoch:2 , loss:1.1102570295333862\n",
      "epoch:2 , loss:1.1041171550750732\n",
      "epoch:2 , loss:1.107223391532898\n",
      "epoch:2 , loss:1.0949138402938843\n",
      "epoch:3 , loss:1.0810829401016235\n",
      "epoch:3 , loss:1.0818899869918823\n",
      "epoch:3 , loss:1.091726541519165\n",
      "epoch:3 , loss:1.0916234254837036\n",
      "epoch:3 , loss:1.0867525339126587\n",
      "epoch:3 , loss:1.099226951599121\n",
      "epoch:3 , loss:1.0857096910476685\n",
      "epoch:3 , loss:1.0940146446228027\n",
      "epoch:3 , loss:1.093887209892273\n",
      "epoch:3 , loss:1.1105033159255981\n",
      "epoch:4 , loss:1.1154266595840454\n",
      "epoch:4 , loss:1.0871334075927734\n",
      "epoch:4 , loss:1.0580840110778809\n",
      "epoch:4 , loss:1.1201266050338745\n",
      "epoch:4 , loss:1.1135869026184082\n",
      "epoch:4 , loss:1.0984891653060913\n",
      "epoch:4 , loss:1.0750253200531006\n",
      "epoch:4 , loss:1.0807818174362183\n",
      "epoch:4 , loss:1.0894041061401367\n",
      "epoch:4 , loss:1.118432879447937\n"
     ]
    }
   ],
   "source": [
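    "for epoch in range(epochs):\n",
    "    for batch_idx, (x, y) in enumerate(batchdata):\n",
    "        # CrossEntropyLoss expects class indices of shape (batch,) and dtype long\n",
    "        y = y.view(x.shape[0]).long()\n",
    "        logits = net(x)  # call the module rather than net.forward(x) so hooks run\n",
    "        loss = criterion(logits, y)\n",
    "        opt.zero_grad()  # clear gradients left over from the previous step\n",
    "        loss.backward()  # backpropagate\n",
    "        opt.step()       # update the parameters\n",
    "        print(\"epoch:{} , loss:{}\".format(epoch, loss.item()))"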
   ]
  },
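  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A note on the printed values, offered as a reading of the output rather than a claim from the original notebook: the labels `Y` are drawn at random, independently of `X`, so there is no signal to learn. A classifier that assigns equal probability to each of the three classes has cross-entropy\n",
    "\n",
    "$$-\\ln\\tfrac{1}{3} = \\ln 3 \\approx 1.0986,$$\n",
    "\n",
    "which is about where the loss stays over the five epochs."
   ]
  },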
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
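    "# Building a neural network model on real data\n",
    "\n",
    "Three operations used in this part:\n",
    "\n",
    "- `torch.max(sigma, 1)[1]`: the index of the largest value in each row, i.e. the predicted class for each sample\n",
    "\n",
    "- `torch.sum(r1 == y)`: the number of predictions `r1` that match the labels `y`\n",
    "\n",
    "- `x.view(-1, 28*28)`: flattens each 28×28 image into a vector of 784 features"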
   ]
  },
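  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The three operations can be tried on a small hand-made example. The score matrix `sigma`, the labels `y` and the random image batch `x` below are hypothetical values chosen for illustration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "# Hypothetical scores for 4 samples and 3 classes. Raw logits work equally well,\n",
    "# since softmax does not change which entry in a row is largest.\n",
    "sigma = torch.tensor([[0.1, 0.7, 0.2],\n",
    "                      [0.8, 0.1, 0.1],\n",
    "                      [0.3, 0.3, 0.4],\n",
    "                      [0.2, 0.5, 0.3]])\n",
    "y = torch.tensor([1, 0, 2, 2])\n",
    "\n",
    "r1 = torch.max(sigma, 1)[1]   # predicted class per row: tensor([1, 0, 2, 1])\n",
    "correct = torch.sum(r1 == y)  # number of correct predictions: tensor(3)\n",
    "accuracy = correct.item() / y.shape[0]\n",
    "print(r1, correct, accuracy)  # accuracy is 0.75\n",
    "\n",
    "x = torch.rand(4, 1, 28, 28)    # a batch of 4 single-channel 28x28 images\n",
    "print(x.view(-1, 28*28).shape)  # torch.Size([4, 784])"
   ]
  },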
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Raw string, so the backslashes in the Windows path are not read as escape sequences\n",
    "mnist = torchvision.datasets.FashionMNIST(root=r'H:\\data\\FashionMNIST', train=True, download=True, transform=transforms.ToTensor())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 69,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Dataset FashionMNIST\n",
       "    Number of datapoints: 60000\n",
       "    Root location: H:\\data\\FashionMNIST\n",
       "    Split: Train\n",
       "    StandardTransform\n",
       "Transform: ToTensor()"
      ]
     },
     "execution_count": 69,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "mnist"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 70,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<matplotlib.image.AxesImage at 0x2dc3f0199a0>"
      ]
     },
     "execution_count": 70,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAaAAAAGdCAYAAABU0qcqAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAilUlEQVR4nO3df3DU9b3v8dfm1xIg2RBCfknAgAoqEFsKMdVSlFwgnesF5fRq650DvY4eaXCK9IdDj4r2dE5anGO9tVTvndNCnSnaOlfkyLHcKjShtGALwqXWNgdoFCwk/KjZDQlJNtnP/YNrNArC+8smnyQ8HzM7Q3a/L74fvnyTV77Z3XdCzjknAAD6WYrvBQAALk0UEADACwoIAOAFBQQA8IICAgB4QQEBALyggAAAXlBAAAAv0nwv4MMSiYSOHDmirKwshUIh38sBABg559TS0qLi4mKlpJz7OmfAFdCRI0dUUlLiexkAgIt0+PBhjR079pyPD7gCysrKkiTdqM8pTemeVwMAsOpSXNv1cs/X83PpswJas2aNHnvsMTU2NqqsrExPPvmkZs6ced7cez92S1O60kIUEAAMOv9/wuj5nkbpkxch/OxnP9OKFSu0atUqvf766yorK9O8efN07NixvtgdAGAQ6pMCevzxx3X33XfrS1/6kq655ho9/fTTGj58uH784x/3xe4AAINQ0guos7NTu3fvVmVl5fs7SUlRZWWlduzY8ZHtOzo6FIvFet0AAENf0gvoxIkT6u7uVkFBQa/7CwoK1NjY+JHta2pqFIlEem68Ag4ALg3e34i6cuVKRaPRntvhw4d9LwkA0A+S/iq4vLw8paamqqmpqdf9TU1NKiws/Mj24XBY4XA42csAAAxwSb8CysjI0PTp07Vly5ae+xKJhLZs2aKKiopk7w4AMEj1yfuAVqxYocWLF+tTn/qUZs6cqSeeeEKtra360pe+1Be7AwAMQn1SQLfffruOHz+uhx9+WI2Njbruuuu0efPmj7wwAQBw6Qo555zvRXxQLBZTJBLRbC1gEgIADEJdLq5abVQ0GlV2dvY5t/P+KjgAwKWJAgIAeEEBAQC8oIAAAF5QQAAALyggAIAXFBAAwAsKCADgBQUEAPCCAgIAeEEBAQC8oIAAAF5QQAAALyggAIAXFBAAwAsKCADgBQUEAPCCAgIAeEEBAQC8oIAAAF5QQAAALyggAIAXFBAAwAsKCADgBQUEAPCCAgIAeJHmewHAgBIK2TPOJX8dZ5E6OteceXfeVYH2lb1+Z6CcWYDjHUpLN2dcvNOcGfCCnKtB9dE5zhUQAMALCggA4AUFBADwggICAHhBAQEAvKCAAABeUEAAAC8oIACAFxQQAMALCggA4AUFBADwggICAHjBMFLgA0KpqeaM6+oyZ1Kuu8ac+dM/jLTv57Q5IklKb51pzqSdTtj388td5ky/DhYNMiw1wDmkkP1aoD+PQyjNVhUh56QL+LTgCggA4AUFBADwggICAHhBAQEAvKCAAABeUEAAAC8oIACAFxQQAMALCggA4AUFBADwggICAHhBAQEAvGAYKfAB1qGLUrBhpIfn5Zgzd1b82pz5zfEJ5owkvR0uNGdcpn0/aZUV5sxVP/yrOdP11iFzRpLknD0S4HwIInXUqGDB7m57JBYzbe/chR0DroAAAF5QQAAAL5JeQI888ohCoVCv2+TJk5O9GwDAINcnzwFde+21evXVV9/fSYCfqwMAhrY+aYa0tDQVFtqfxAQAXDr65Dmg/fv3q7i4WBMmTNCdd96pQ4fO/QqUjo4OxWKxXjcAwNCX9AIqLy/XunXrtHnzZj311FNqaGjQZz7zGbW0tJx1+5qaGkUikZ5bSUlJspcEABiAkl5AVVVV+vznP69p06Zp3rx5evnll9Xc3Kyf//znZ91+5cqVikajPbfDhw8ne0kAgAGoz18dkJOTo6uuukoHDhw46+PhcFjhcLivlwEAGGD6
/H1Ap06d0sGDB1VUVNTXuwIADCJJL6Cvfe1rqqur01tvvaXf/va3uvXWW5WamqovfOELyd4VAGAQS/qP4N555x194Qtf0MmTJzVmzBjdeOON2rlzp8aMGZPsXQEABrGkF9Bzzz2X7L8S6DeJ9vZ+2U/nJ06ZM38X2WXODEuJmzOSVJeSMGf+utX+Ctbuafbj8PbjWeZMYs+nzRlJGv2GfXBn9p6j5syJWZeZM8en2welSlLBTntm1KsHTdu7RKd04vzbMQsOAOAFBQQA8IICAgB4QQEBALyggAAAXlBAAAAvKCAAgBcUEADACwoIAOAFBQQA8IICAgB4QQEBALzo819IB3gRCgXLOfuAx1P/9Xpz5u+vqTVnDsbtE+XHZvzNnJGkzxfvtof+mz3zg/rPmjOtf4mYMykjgg3ubLze/j36XxfY/59cvMucGfV6sC/fKYubzJlY5wTT9l3xdmnjBazFvBIAAJKAAgIAeEEBAQC8oIAAAF5QQAAALyggAIAXFBAAwAsKCADgBQUEAPCCAgIAeEEBAQC8oIAAAF5QQAAAL5iGjf4VdEr1AHb9A78zZ24a+WYfrOSjLlOwKdCtLsOcae4eYc6suubfzZnjV2WZM3EX7Evdv+7/tDlzKsC07tQu++fF9f99jzkjSYtyf2/OrP7fU03bd7n4BW3HFRAAwAsKCADgBQUEAPCCAgIAeEEBAQC8oIAAAF5QQAAALyggAIAXFBAAwAsKCADgBQUEAPCCAgIAeMEwUvQvF2w45kC2/1S+OXMye6Q509iVY86MTj1lzkhSVsppc+by9BPmzPFu+2DR1PSEOdPpUs0ZSXr02pfMmfar082Z9FC3OfPpYUfMGUn6/Jt/b86M0F8C7et8uAICAHhBAQEAvKCAAABeUEAAAC8oIACAFxQQAMALCggA4AUFBADwggICAHhBAQEAvKCAAABeUEAAAC8YRgpcpDFh+8DPYaG4OZMR6jJnjsRHmTOStP/0JHPmP2L2oazzC/5ozsQDDBZNVbAhuEGGhBanv2vOtDv7AFP7GXTGDQX2waJ7A+7rfLgCAgB4QQEBALwwF9C2bdt0yy23qLi4WKFQSC+++GKvx51zevjhh1VUVKTMzExVVlZq//79yVovAGCIMBdQa2urysrKtGbNmrM+vnr1an3/+9/X008/rddee00jRozQvHnz1N7eftGLBQAMHeYXIVRVVamqquqsjznn9MQTT+jBBx/UggULJEnPPPOMCgoK9OKLL+qOO+64uNUCAIaMpD4H1NDQoMbGRlVWVvbcF4lEVF5erh07dpw109HRoVgs1usGABj6klpAjY2NkqSCgoJe9xcUFPQ89mE1NTWKRCI9t5KSkmQuCQAwQHl/FdzKlSsVjUZ7bocPH/a9JABAP0hqARUWFkqSmpqaet3f1NTU89iHhcNhZWdn97oBAIa+pBZQaWmpCgsLtWXLlp77YrGYXnvtNVVUVCRzVwCAQc78KrhTp07pwIEDPR83NDRo7969ys3N1bhx47R8+XJ9+9vf1pVXXqnS0lI99NBDKi4u1sKFC5O5bgDAIGcuoF27dummm27q+XjFihWSpMWLF2vdunX6xje+odbWVt1zzz1qbm7WjTfeqM2bN2vYsGHJWzUAYNALOeeCTenrI7FYTJFIRLO1QGkh+4A+DHChkD2Sah8+6brsgzslKXWUfXjnHTv+YN9PyP5pd7wry5zJSW0zZySprtk+jPSPJ8/+PO/H+dakfzNnXm+73JwpzrAPCJWCHb+3OvPMmSvDZ3+V8Mf5xbtl5owklQz7mznzy+WzTNt3dbVre+2jikajH/u8vvdXwQEALk0UEADACwoIAOAFBQQA8IICAgB4QQEBALyggAAAXlBAAAAvKCAAgBcUEADACwoIAOAFBQQA8IICAgB4Yf51DMBFCTB8PZRmP02DTsM+fNfV5szNw18yZ37bfpk5MyatxZyJO/skcUkqCkfNmayCdnOmuXu4OZObdsqc
aenONGckaXhKhzkT5P/pkxknzJn7X/2kOSNJWVNOmjPZ6bZrlcQFXttwBQQA8IICAgB4QQEBALyggAAAXlBAAAAvKCAAgBcUEADACwoIAOAFBQQA8IICAgB4QQEBALyggAAAXjCMFP0qlJ5hziTa7UMug8r7Q6c5c6I73ZzJSWkzZzJC3eZMZ8BhpJ/ObTBnjgcY+Pn66VJzJiv1tDkzJsU+IFSSStLtgzv/0F5izrzceoU5c9d/ftWckaRn/9d/MmcyNv/WtH2Ki1/YduaVAACQBBQQAMALCggA4AUFBADwggICAHhBAQEAvKCAAABeUEAAAC8oIACAFxQQAMALCggA4AUFBADw4tIeRhoKBYul2YdPhlIDdH2KPZNo77DvJ2EfchmUi9uHffan//E/f2DOHO7KMWca4/ZMTqp9gGm3gp3jO09HzJlhKRc2gPKDxqTFzJlYwj70NKiWxDBzJh5gAGyQY/fA6P3mjCS9EK0MlOsLXAEBALyggAAAXlBAAAAvKCAAgBcUEADACwoIAOAFBQQA8IICAgB4QQEBALyggAAAXlBAAAAvKCAAgBdDZhhpKM3+T3FdXYH2FWSgprPPGhySTi+Yac4cXmgflnrnJ35nzkhSY1eWObOn7XJzJpJ62pwZkWIfNNvu7INzJelI5yhzJshAzdy0U+ZMfoABpt0u2Pfaf43bj0MQQQbNvtNlP3aS1PJfWsyZnGcC7eq8uAICAHhBAQEAvDAX0LZt23TLLbeouLhYoVBIL774Yq/HlyxZolAo1Os2f/78ZK0XADBEmAuotbVVZWVlWrNmzTm3mT9/vo4ePdpze/bZZy9qkQCAocf8zH1VVZWqqqo+dptwOKzCwsLAiwIADH198hxQbW2t8vPzNWnSJC1dulQnT54857YdHR2KxWK9bgCAoS/pBTR//nw988wz2rJli7773e+qrq5OVVVV6u4++0tpa2pqFIlEem4lJSXJXhIAYABK+vuA7rjjjp4/T506VdOmTdPEiRNVW1urOXPmfGT7lStXasWKFT0fx2IxSggALgF9/jLsCRMmKC8vTwcOHDjr4+FwWNnZ2b1uAIChr88L6J133tHJkydVVFTU17sCAAwi5h/BnTp1qtfVTENDg/bu3avc3Fzl5ubq0Ucf1aJFi1RYWKiDBw/qG9/4hq644grNmzcvqQsHAAxu5gLatWuXbrrppp6P33v+ZvHixXrqqae0b98+/eQnP1Fzc7OKi4s1d+5c/dM//ZPC4XDyVg0AGPRCzjnnexEfFIvFFIlENFsLlBYKNkhxIEorsr8vKl5aYM787erh5kxbYcickaTrPvcnc2ZJwXZz5ni3/XnB9FCwQbMt3ZnmTGF6szmzNXqNOTMyzT6MNMjQU0n6ZOZb5kxzwn7uFae9a848cODvzJmC4fYBnJL0r+NfNmfiLmHO1Mft36BnpdiHIkvSr9uuMGc2XDPGtH2Xi6tWGxWNRj/2eX1mwQEAvKCAAABeUEAAAC8oIACAFxQQAMALCggA4AUFBADwggICAHhBAQEAvKCAAABeUEAAAC8oIACAFxQQAMCLpP9Kbl86qmaYM/n/+JdA+7ou+x1z5ppM+xTo9oR9GviwlLg58+bpy8wZSWpLZJgz+zvtU8GjXfYpy6kh+0RiSTrWmWXO/EtDpTmzZebT5syDR+abMymZwYbdn+weac4sGhkLsCf7Of4P47aZMxMyjpkzkrSp1f6LNI/ER5kzBelRc+by9OPmjCTdlvUf5swG2aZhXyiugAAAXlBAAAAvKCAAgBcUEADACwoIAOAFBQQA8IICAgB4QQEBALyggAAAXlBAAAAvKCAAgBcUEADAiwE7jDSUlqZQ6MKXV/7PvzfvY07WH80ZSWpzYXMmyGDRIEMNg4iktQXKdcTtp8+xeHagfVldFW4MlLs1e685s+0H5ebMje33mTMHb15rzmw5nWrOSNLxLvv/0x0NN5szrx8qMWeuv7zBnJma9VdzRgo2CDcr
td2cSQ91mTOtCfvXIUna2W4fNNtXuAICAHhBAQEAvKCAAABeUEAAAC8oIACAFxQQAMALCggA4AUFBADwggICAHhBAQEAvKCAAABeUEAAAC8G7DDSo0unKzU87IK3fyTypHkf6/92vTkjSSXD/mbOjM84Yc6UZb5tzgSRlWIfnihJk7LtAxQ3tY41Z2qbJ5szRenN5owk/bptojnz3COPmTNL7v+qOVPx8r3mTOzyYN9jdo1w5kx22Ulz5sFP/Ls5kxHqNmeau+1DRSUpN9xqzuSkBhvuaxVkKLIkZaWcNmdSJ11h2t51d0j7z78dV0AAAC8oIACAFxQQAMALCggA4AUFBADwggICAHhBAQEAvKCAAABeUEAAAC8oIACAFxQQAMALCggA4MWAHUY6/FhCqRmJC95+U+w68z4mZB43ZyTpRDzLnPk/p6aaM2Mz3zVnIqn2QYNXhBvNGUna255jzmw+fq05U5wZM2ea4hFzRpJOxkeYM20J+1DIH33vcXPmX5oqzZlbc183ZySpLMM+WLQ5Yf9+9s3OQnOmJXHhQ4rf0+7SzRlJigYYYpoV4HMw7uxfilPdhX99/KCcFPuw1NjU0abtu+LtDCMFAAxcFBAAwAtTAdXU1GjGjBnKyspSfn6+Fi5cqPr6+l7btLe3q7q6WqNHj9bIkSO1aNEiNTU1JXXRAIDBz1RAdXV1qq6u1s6dO/XKK68oHo9r7ty5am19/5c23X///XrppZf0/PPPq66uTkeOHNFtt92W9IUDAAY30zNfmzdv7vXxunXrlJ+fr927d2vWrFmKRqP60Y9+pPXr1+vmm2+WJK1du1ZXX321du7cqeuvD/YbSAEAQ89FPQcUjUYlSbm5uZKk3bt3Kx6Pq7Ly/VfrTJ48WePGjdOOHTvO+nd0dHQoFov1ugEAhr7ABZRIJLR8+XLdcMMNmjJliiSpsbFRGRkZysnJ6bVtQUGBGhvP/lLfmpoaRSKRnltJSUnQJQEABpHABVRdXa033nhDzz333EUtYOXKlYpGoz23w4cPX9TfBwAYHAK9EXXZsmXatGmTtm3bprFjx/bcX1hYqM7OTjU3N/e6CmpqalJh4dnfcBYOhxUO29/IBwAY3ExXQM45LVu2TBs2bNDWrVtVWlra6/Hp06crPT1dW7Zs6bmvvr5ehw4dUkVFRXJWDAAYEkxXQNXV1Vq/fr02btyorKysnud1IpGIMjMzFYlEdNddd2nFihXKzc1Vdna27rvvPlVUVPAKOABAL6YCeuqppyRJs2fP7nX/2rVrtWTJEknS9773PaWkpGjRokXq6OjQvHnz9MMf/jApiwUADB0h55zzvYgPisViikQimnXjQ0pLu/ChgzOe2G3e1xuxYnNGkgqGtZgz00a+Y87Ut9kHNR45nW3ODE+LmzOSlJlqz3U5++te8sP24z0ubB+mKUlZKfZBkhmhbnOmO8Drf67NOGLOHOoaZc5IUmNXjjnzZpv982lUmn0w5h8CfN62dWWYM5LU0W1/mry9y56JhNvNmRm5b5szkpQi+5f89f/2WdP2ifZ2/eXb/6hoNKrs7HN/TWIWHADACwoIAOAFBQQA8IICAgB4QQEBALyggAAAXlBAAAAvKCAAgBcUEADACwoIAOAFBQQA8IICAgB4QQEBALwI9BtR+0PK9n1KCaVf8PbP//IG8z4eWvC8OSNJdc2TzZlNjVPNmVin/TfFjhneas5kp9unTUtSbrp9X5EA04+HhbrMmXe7RpgzktSRcuHn3Hu6FTJnGjsi5sxvEleaM/FEqjkjSR0BckGmo/+tM8+cKc6MmjMtXRc+Wf+D3mrJNWdOREeaM+3D7V+Kt3dPNGckaX7hH82ZzGO2c7y748K25woIAOAFBQQA8IICAgB4QQEBALyggAAAXlBAAAAvKCAAgBcUEADACwoIAOAFBQQA8IICAgB4QQEBALwIOeec70V8UCwWUyQS0WwtUJphGGkQ0TuvD5Sb8OV6c2ZmToM583ps
nDlzKMDwxHgi2Pch6SkJc2Z4eqc5MyzAkMuM1G5zRpJSZP90SAQYRjoi1X4cRqR1mDPZae3mjCRlpdpzKSH7+RBEaoD/o99FL0/+Qs4hK8D/U5ezfw5WRA6aM5L044ZPmzORzx0wbd/l4qrVRkWjUWVnZ59zO66AAABeUEAAAC8oIACAFxQQAMALCggA4AUFBADwggICAHhBAQEAvKCAAABeUEAAAC8oIACAFxQQAMCLgTuMNOU22zDSRLDhk/2ldVG5OVP+zd/bM1n2AYWTM5rMGUlKl3345LAAAytHpNiHfbYHPK2DfEe2/XSJOdMdYE9b373anIkHGHIpSU1t5x4geS7pAQfAWiWc/Xw43RVssHH09DBzJjXFfu611+aZM6PftA/plaTwy/avK1YMIwUADGgUEADACwoIAOAFBQQA8IICAgB4QQEBALyggAAAXlBAAAAvKCAAgBcUEADACwoIAOAFBQQA8GLgDiPVAtswUgQWmjE1UO50YaY5Ez7ZYc60jLfvJ/tgqzkjSSkdXeZM4v/+KdC+gKGKYaQAgAGNAgIAeGEqoJqaGs2YMUNZWVnKz8/XwoULVV9f32ub2bNnKxQK9brde++9SV00AGDwMxVQXV2dqqurtXPnTr3yyiuKx+OaO3euWlt7/7z97rvv1tGjR3tuq1evTuqiAQCDX5pl482bN/f6eN26dcrPz9fu3bs1a9asnvuHDx+uwsLC5KwQADAkXdRzQNFoVJKUm5vb6/6f/vSnysvL05QpU7Ry5Uq1tbWd8+/o6OhQLBbrdQMADH2mK6APSiQSWr58uW644QZNmTKl5/4vfvGLGj9+vIqLi7Vv3z498MADqq+v1wsvvHDWv6empkaPPvpo0GUAAAapwO8DWrp0qX7xi19o+/btGjt27Dm327p1q+bMmaMDBw5o4sSJH3m8o6NDHR3vvzckFouppKSE9wH1I94H9D7eBwRcvAt9H1CgK6Bly5Zp06ZN2rZt28eWjySVl5dL0jkLKBwOKxwOB1kGAGAQMxWQc0733XefNmzYoNraWpWWlp43s3fvXklSUVFRoAUCAIYmUwFVV1dr/fr12rhxo7KystTY2ChJikQiyszM1MGDB7V+/Xp97nOf0+jRo7Vv3z7df//9mjVrlqZNm9Yn/wAAwOBkKqCnnnpK0pk3m37Q2rVrtWTJEmVkZOjVV1/VE088odbWVpWUlGjRokV68MEHk7ZgAMDQYP4R3McpKSlRXV3dRS0IAHBpCPwybAwd7vd/CJQbluR1nEv2b/tpR5IS/bcr4JLHMFIAgBcUEADACwoIAOAFBQQA8IICAgB4QQEBALyggAAAXlBAAAAvKCAAgBcUEADACwoIAOAFBQQA8IICAgB4QQEBALyggAAAXlBAAAAvKCAAgBcUEADACwoIAOAFBQQA8IICAgB4QQEBALyggAAAXlBAAAAv0nwv4MOcc5KkLsUl53kxAACzLsUlvf/1/FwGXAG1tLRIkrbrZc8rAQBcjJaWFkUikXM+HnLnq6h+lkgkdOTIEWVlZSkUCvV6LBaLqaSkRIcPH1Z2dranFfrHcTiD43AGx+EMjsMZA+E4OOfU0tKi4uJipaSc+5meAXcFlJKSorFjx37sNtnZ2Zf0CfYejsMZHIczOA5ncBzO8H0cPu7K5z28CAEA4AUFBADwYlAVUDgc1qpVqxQOh30vxSuOwxkchzM4DmdwHM4YTMdhwL0IAQBwaRhUV0AAgKGDAgIAeEEBAQC8oIAAAF4MmgJas2aNLr/8cg0bNkzl5eX63e9+53tJ/e6RRx5RKBTqdZs8ebLvZfW5bdu26ZZbblFxcbFCoZBefPHFXo875/Twww+rqKhImZmZqqys1P79+/0stg+d7zgsWbLkI+fH/Pnz/Sy2j9TU1GjGjBnKyspSfn6+Fi5cqPr6+l7btLe3q7q6WqNHj9bIkSO1aNEiNTU1eVpx37iQ4zB79uyPnA/33nuvpxWf3aAooJ/97Gda
sWKFVq1apddff11lZWWaN2+ejh075ntp/e7aa6/V0aNHe27bt2/3vaQ+19raqrKyMq1Zs+asj69evVrf//739fTTT+u1117TiBEjNG/ePLW3t/fzSvvW+Y6DJM2fP7/X+fHss8/24wr7Xl1dnaqrq7Vz50698sorisfjmjt3rlpbW3u2uf/++/XSSy/p+eefV11dnY4cOaLbbrvN46qT70KOgyTdfffdvc6H1atXe1rxObhBYObMma66urrn4+7ubldcXOxqamo8rqr/rVq1ypWVlflehleS3IYNG3o+TiQSrrCw0D322GM99zU3N7twOOyeffZZDyvsHx8+Ds45t3jxYrdgwQIv6/Hl2LFjTpKrq6tzzp35v09PT3fPP/98zzZ/+tOfnCS3Y8cOX8vscx8+Ds4599nPftZ95Stf8beoCzDgr4A6Ozu1e/duVVZW9tyXkpKiyspK7dixw+PK/Ni/f7+Ki4s1YcIE3XnnnTp06JDvJXnV0NCgxsbGXudHJBJReXn5JXl+1NbWKj8/X5MmTdLSpUt18uRJ30vqU9FoVJKUm5srSdq9e7fi8Xiv82Hy5MkaN27ckD4fPnwc3vPTn/5UeXl5mjJlilauXKm2tjYfyzunATeM9MNOnDih7u5uFRQU9Lq/oKBAf/7znz2tyo/y8nKtW7dOkyZN0tGjR/Xoo4/qM5/5jN544w1lZWX5Xp4XjY2NknTW8+O9xy4V8+fP12233abS0lIdPHhQ3/zmN1VVVaUdO3YoNTXV9/KSLpFIaPny5brhhhs0ZcoUSWfOh4yMDOXk5PTadiifD2c7DpL0xS9+UePHj1dxcbH27dunBx54QPX19XrhhRc8rra3AV9AeF9VVVXPn6dNm6by8nKNHz9eP//5z3XXXXd5XBkGgjvuuKPnz1OnTtW0adM0ceJE1dbWas6cOR5X1jeqq6v1xhtvXBLPg36ccx2He+65p+fPU6dOVVFRkebMmaODBw9q4sSJ/b3MsxrwP4LLy8tTamrqR17F0tTUpMLCQk+rGhhycnJ01VVX6cCBA76X4s175wDnx0dNmDBBeXl5Q/L8WLZsmTZt2qRf/epXvX59S2FhoTo7O9Xc3Nxr+6F6PpzrOJxNeXm5JA2o82HAF1BGRoamT5+uLVu29NyXSCS0ZcsWVVRUeFyZf6dOndLBgwdVVFTkeynelJaWqrCwsNf5EYvF9Nprr13y58c777yjkydPDqnzwzmnZcuWacOGDdq6datKS0t7PT59+nSlp6f3Oh/q6+t16NChIXU+nO84nM3evXslaWCdD75fBXEhnnvuORcOh926devcm2++6e655x6Xk5PjGhsbfS+tX331q191tbW1rqGhwf3mN79xlZWVLi8vzx07dsz30vpUS0uL27Nnj9uzZ4+T5B5//HG3Z88e9/bbbzvnnPvOd77jcnJy3MaNG92+ffvcggULXGlpqTt9+rTnlSfXxx2HlpYW97Wvfc3t2LHDNTQ0uFdffdV98pOfdFdeeaVrb2/3vfSkWbp0qYtEIq62ttYdPXq059bW1tazzb333uvGjRvntm7d6nbt2uUqKipcRUWFx1Un3/mOw4EDB9y3vvUtt2vXLtfQ0OA2btzoJkyY4GbNmuV55b0NigJyzrknn3zSjRs3zmVkZLiZM2e6nTt3+l5Sv7v99ttdUVGRy8jIcJdddpm7/fbb3YEDB3wvq8/96le/cpI+clu8eLFz7sxLsR966CFXUFDgwuGwmzNnjquvr/e76D7wccehra3NzZ07140ZM8alp6e78ePHu7vvvnvIfZN2tn+/JLd27dqebU6fPu2+/OUvu1GjRrnhw4e7W2+91R09etTfovvA+Y7DoUOH3KxZs1xubq4Lh8PuiiuucF//+tddNBr1u/AP4dcxAAC8GPDPAQEAhiYKCADgBQUEAPCCAgIAeEEBAQC8oIAAAF5QQAAALyggAIAXFBAAwAsKCADgBQUEAPCCAgIAePH/AIe0yFA5
VNd3AAAAAElFTkSuQmCC",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "plt.imshow(mnist[0][0].view((28, 28)).numpy())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([60000, 28, 28])"
      ]
     },
     "execution_count": 52,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "mnist.targets\n",
    "mnist.data.shape\n",
    "#mnist[0][0].view(-1, 28*28)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define the network architecture for 28x28 FashionMNIST images (10 classes)\n",
    "class Model_mnist(nn.Module):\n",
    "    def __init__(self, in_features=28*28, out_features=10):\n",
    "        super().__init__()\n",
    "        self.linear1 = nn.Linear(in_features, 128)\n",
    "        self.output = nn.Linear(128, out_features)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = x.view(-1, 28*28)  # flatten (N, 1, 28, 28) images to (N, 784)\n",
    "        sigma1 = torch.relu(self.linear1(x))\n",
    "        z2 = self.output(sigma1)  # raw logits, one per class\n",
    "        return z2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [],
   "source": [
    "#mnist[0][0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([28, 28])"
      ]
     },
     "execution_count": 41,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "mnist.data[0].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "784"
      ]
     },
     "execution_count": 42,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "28*28"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 78,
   "metadata": {},
   "outputs": [],
   "source": [
    "#mnist.data[0]"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
